import numpy as np
import torch
import torch.nn.functional as F
import copy

def init_key_query_args(args, use_broadcast_mask=False):
    """Build the argument namespace shared by the key and query networks.

    The key and query networks use the same dimensions, so one namespace
    serves both. ``args`` itself is never modified: every override is
    applied to a deep copy, which is returned.

    Args:
        args: configuration namespace; this function reads ``embed_dim``,
            ``hidden_sizes``, ``factor_net.append_broadcast_mask`` and
            ``mask_attn.{model_dim, num_heads, no_hidden}``.
        use_broadcast_mask: when truthy, the input width is increased by
            ``args.factor_net.append_broadcast_mask``.

    Returns:
        A deep copy of ``args`` with the input/output widths, hidden sizes,
        pair-layer count, and final aggregation/activation overridden.
    """
    attn_cfg = args.mask_attn
    # the broadcast mask contributes extra input features only when requested
    input_width = args.embed_dim + (int(use_broadcast_mask) * args.factor_net.append_broadcast_mask)
    # one model_dim-sized slice per attention head
    output_width = attn_cfg.model_dim * attn_cfg.num_heads
    if attn_cfg.no_hidden:
        hidden_widths = list()
    else:
        hidden_widths = [width * attn_cfg.num_heads for width in args.hidden_sizes]

    query_args = copy.deepcopy(args)
    # num_inputs/object_dim and num_outputs/output_dim are set in pairs so the
    # same namespace handles both mlp and conv networks
    query_args.num_inputs = input_width
    query_args.object_dim = input_width
    query_args.num_outputs = output_width
    query_args.output_dim = output_width
    query_args.hidden_sizes = hidden_widths
    query_args.factor_net.num_pair_layers = 1
    query_args.aggregate_final = False
    query_args.activation_final = "none"
    return query_args

def init_final_args(args):
    """Build the argument namespace for the layer applied after multi-head attention.

    When ``args.mask_attn.merge_function`` is ``"cat"`` the input width and
    every hidden width are multiplied by the number of heads; otherwise the
    un-multiplied widths are used. ``args`` is left untouched; a modified
    deep copy is returned.

    Args:
        args: configuration namespace; this function reads ``embed_dim``,
            ``activation``, ``factor_net.final_layers`` and
            ``mask_attn.{merge_function, model_dim, num_heads,
            attention_dropout, attention_layer_norm}``.

    Returns:
        A deep copy of ``args`` with input/output widths, dropout, layer-norm
        flag, hidden sizes, and final activation overridden.
    """
    attn_cfg = args.mask_attn
    if attn_cfg.merge_function == "cat":
        # concatenated heads: widths must cover model_dim * num_heads
        input_width = attn_cfg.model_dim * attn_cfg.num_heads
        hidden_widths = [width * attn_cfg.num_heads for width in args.factor_net.final_layers]
    else:
        input_width = attn_cfg.model_dim
        hidden_widths = list(args.factor_net.final_layers)

    final_args = copy.deepcopy(args)
    final_args.num_inputs = input_width
    final_args.object_dim = input_width
    final_args.num_outputs = args.embed_dim
    final_args.output_dim = args.embed_dim
    final_args.dropout = attn_cfg.attention_dropout
    final_args.use_layer_norm = attn_cfg.attention_layer_norm
    final_args.hidden_sizes = hidden_widths
    # the final layer reuses the network's regular activation
    final_args.activation_final = final_args.activation
    return final_args


def evaluate_key_query(softmax, keys, queries, mask, single_key=False, gumbel=-1, renormalize=False, use_sigmoid=False):
    """Compute (optionally masked) attention weights from keys and queries.

    Args:
        softmax: callable applied to the scaled logits on the deterministic
            softmax path; expected to normalize along dim=-1.
        keys: tensor laid out as batch, heads, embed, num_keys.
        queries: tensor laid out as batch, heads, num_queries, embed.
        mask: ``None``, or a multiplicative (probabilistic) mask. A 2-D mask
            (batch, num_queries) is shared across keys; a 3-D mask
            (batch, num_keys or 1, num_queries) is applied per key. In both
            cases it is shared across heads.
        single_key: if True, only the weights of key 0 are returned.
        gumbel: temperature for stochastic weights; values <= 0 select the
            deterministic sigmoid/softmax.
        renormalize: if True, divide by the sum over the query axis (plus
            1e-4) after masking.
        use_sigmoid: if True, use independent sigmoid weights instead of a
            softmax over queries.

    Returns:
        Weights of shape batch x heads x keys x queries, or
        batch x heads x queries when ``single_key`` is True.
    """
    # (batch, heads, num_queries, embed) @ (batch, heads, embed, num_keys), then swap the last two axes
    logits = torch.matmul(queries, keys).transpose(-1, -2)  # batch, heads, num_keys, num_queries
    scale = np.sqrt(queries.shape[-1])
    if use_sigmoid:  # sigmoid preserves magnitude across dimensions if there is no upcoming summation
        if gumbel > 0:
            # torch.nn.functional has no gumbel_bernoulli; draw a relaxed
            # Bernoulli (binary Concrete) sample instead: add logistic noise
            # to the logits and squash with a temperature-scaled sigmoid.
            # Like the gumbel-softmax branch, the logits are not rescaled.
            eps = torch.finfo(logits.dtype).eps
            uniform = torch.rand_like(logits).clamp(eps, 1 - eps)  # keep the logs finite
            logistic_noise = torch.log(uniform) - torch.log1p(-uniform)
            weights = torch.sigmoid((logits + logistic_noise) / gumbel)
        else:
            weights = torch.sigmoid(logits / scale)
    else:
        if gumbel > 0:
            # stochastic softmax weights; shape is unchanged, queries are on the last axis
            weights = F.gumbel_softmax(logits, tau=gumbel, hard=False, dim=-1)
        else:
            weights = softmax(logits / scale)  # softmax expected along dim=-1, values in 0,1
    # masks override the weights, then optionally renormalize TODO: use -inf instead?
    if mask is not None:
        if mask.dim() == 2:
            mask = mask.unsqueeze(-2)  # no key axis given: add a singleton one
        # add the heads axis and let broadcasting expand over heads (and keys
        # when the key axis is a singleton)
        weights = weights * mask.unsqueeze(1)
    if renormalize:
        # renormalizing along queries after zeroing out; 1e-4 guards against an all-zero row
        weights = weights / (weights.sum(dim=-1, keepdim=True) + 1e-4)
    if single_key:
        weights = weights[:, :, 0]
    return weights  # batch x heads x keys (if not single key) x queries

def mask_query(queries, mask, valid, single_key = False):
    """Scale ``queries`` by the optional ``mask`` and ``valid`` tensors.

    Args:
        queries: tensor to be scaled.
        mask: ``None`` or a multiplicative mask. With ``single_key`` it gets
            a trailing singleton axis before multiplying; otherwise it is
            multiplied as given.
        valid: ``None`` or a multiplicative validity tensor; always given a
            trailing singleton axis before multiplying.
        single_key: selects how ``mask`` is broadcast (see above).

    Returns:
        The scaled queries; ``queries`` itself when both ``mask`` and
        ``valid`` are ``None``.
    """
    scaled = queries
    if mask is not None:
        if single_key:
            mask_factor = mask.unsqueeze(-1)
        else:
            mask_factor = mask
        scaled = scaled * mask_factor
    if valid is not None:
        scaled = scaled * valid.unsqueeze(-1)
    return scaled